#!/usr/bin/env python3
"""
Phase 1: Panel Composition Statistics (69 vs 140 estimation samples)
Produces Tables 1, 1b, 1c and 1d: panel composition by region and by income
group (with OECD counts), the KAOPEN distribution, and variable coverage.

The 69-country sample = EBA(49) + SSA(20).
The 140-country sample = 69 + EU_EXPANSION(10) + EXPANSION_TIER1(~61).
Both draw from the same underlying 237-country full_panel.csv.
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

# Absolute project root. Hard-coded, so the script only runs on a machine where
# this path exists (NOTE(review): looks like a WSL mount of C: — confirm).
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
# Put the follow-up project's src directory first on sys.path. This must run
# before the `macro` import below (presumably that is where macro.py lives).
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "followup" / "src"))
# Country-code collections that define the two estimation samples.
from macro import EBA_COUNTRIES, SSA_COUNTRIES, EU_EXPANSION, EXPANSION_TIER1

# Input side: full_panel.csv is read from FOLLOWUP_DIR / data / processed.
FOLLOWUP_DIR = PROJECT_DIR / "multilateral" / "followup"
# Output side: every CSV/Markdown table produced by this script goes to TABLE_DIR.
FRAG_DIR = PROJECT_DIR / "fragility"
TABLE_DIR = FRAG_DIR / "output" / "tables"
# Create the output directory (and missing parents) up front; no error if it exists.
TABLE_DIR.mkdir(parents=True, exist_ok=True)

# ── Load panel ─────────────────────────────────────────────────────────────
print("=" * 70)
print("PHASE 1: PANEL COMPOSITION STATISTICS")
print("=" * 70)

# Read the full panel and keep only the 1986-2024 window (both ends inclusive).
panel = pd.read_csv(FOLLOWUP_DIR / "data" / "processed" / "full_panel.csv")
in_window = panel['year'].between(1986, 2024)
panel = panel[in_window].copy()

# Estimation samples, as sets of iso3 codes. The larger sample is the smaller
# one plus the two expansion lists, so `countries_added` is exactly what the
# expansion contributes.
countries_69 = set(EBA_COUNTRIES) | set(SSA_COUNTRIES)
countries_140 = countries_69 | set(EU_EXPANSION) | set(EXPANSION_TIER1)
countries_added = countries_140 - countries_69

# Row subsets of the panel for each sample (independent copies).
in_69 = panel['iso3'].isin(countries_69)
in_140 = panel['iso3'].isin(countries_140)
fp69 = panel[in_69].copy()
fp140 = panel[in_140].copy()

print(f"69-country estimation sample: {fp69['iso3'].nunique()} countries, {len(fp69):,} obs")
print(f"140-country estimation sample: {fp140['iso3'].nunique()} countries, {len(fp140):,} obs")
print(f"Countries added: {len(countries_added)}")
print(f"Added: {sorted(countries_added)}")

# ── OECD membership ───────────────────────────────────────────────────────
# ISO3 codes, listed alphabetically. The set carries no accession dates: a
# country in it is counted as OECD for every year of the panel.
OECD_MEMBERS = {
    'AUS', 'AUT', 'BEL', 'CAN', 'CHE', 'CHL', 'COL', 'CRI', 'CZE', 'DEU',
    'DNK', 'ESP', 'EST', 'FIN', 'FRA', 'GBR', 'GRC', 'HUN', 'IRL', 'ISL',
    'ISR', 'ITA', 'JPN', 'KOR', 'LTU', 'LUX', 'LVA', 'MEX', 'NLD', 'NOR',
    'NZL', 'POL', 'PRT', 'SVK', 'SVN', 'SWE', 'TUR', 'USA',
}

# ── Region classification ─────────────────────────────────────────────────
# Region name -> set of ISO3 codes. Lookup order is the insertion order below.
REGION_MAP = {
    'East Asia & Pacific': {
        'AUS', 'BRN', 'CHN', 'FJI', 'HKG', 'IDN', 'JPN', 'KHM', 'KIR', 'KOR',
        'LAO', 'MAC', 'MHL', 'MMR', 'MNG', 'MYS', 'NRU', 'NZL', 'PHL', 'PLW',
        'PNG', 'SGP', 'SLB', 'THA', 'TLS', 'TON', 'TUV', 'TWN', 'VNM', 'VUT', 'WSM',
        # Deliberately placed here rather than under 'South Asia'.
        # NOTE(review): confirm this grouping matches the paper's region definitions.
        'BTN', 'LKA', 'NPL',
    },
    'Europe & Central Asia': {
        'ALB', 'AND', 'ARM', 'AUT', 'AZE', 'BEL', 'BGR', 'BIH', 'BLR', 'CHE',
        'CYP', 'CZE', 'DEU', 'DNK', 'ESP', 'EST', 'FIN', 'FRA', 'GBR', 'GEO',
        'GRC', 'HRV', 'HUN', 'IRL', 'ISL', 'ITA', 'KAZ', 'KGZ', 'LTU', 'LUX',
        'LVA', 'MDA', 'MKD', 'MLT', 'MNE', 'NLD', 'NOR', 'POL', 'PRT', 'ROU',
        'RUS', 'SMR', 'SRB', 'SVK', 'SVN', 'SWE', 'TJK', 'TKM', 'TUR', 'UKR',
        'UZB', 'XKX',
    },
    'Latin America & Caribbean': {
        'ABW', 'ARG', 'ATG', 'BHS', 'BLZ', 'BOL', 'BRA', 'BRB', 'CHL', 'COL',
        'CRI', 'CUB', 'CUW', 'DMA', 'DOM', 'ECU', 'GRD', 'GTM', 'GUY', 'HND',
        'HTI', 'JAM', 'KNA', 'LCA', 'MEX', 'NIC', 'PAN', 'PER', 'PRY', 'SLV',
        'SUR', 'SXM', 'TCA', 'TTO', 'URY', 'VCT', 'VEN',
    },
    'Middle East & North Africa': {
        'ARE', 'BHR', 'DJI', 'DZA', 'EGY', 'IRN', 'IRQ', 'ISR', 'JOR', 'KWT',
        'LBN', 'LBY', 'MAR', 'OMN', 'PSE', 'QAT', 'SAU', 'SYR', 'TUN', 'YEM',
    },
    'South Asia': {
        'AFG', 'BGD', 'IND', 'MDV', 'PAK',
    },
    'Sub-Saharan Africa': {
        'AGO', 'BDI', 'BEN', 'BFA', 'BWA', 'CAF', 'CIV', 'CMR', 'COD', 'COG',
        'COM', 'CPV', 'ERI', 'ETH', 'GAB', 'GHA', 'GIN', 'GMB', 'GNB', 'GNQ',
        'KEN', 'LBR', 'LSO', 'MDG', 'MLI', 'MOZ', 'MRT', 'MUS', 'MWI', 'NAM',
        'NER', 'NGA', 'RWA', 'SEN', 'SLE', 'SOM', 'SSD', 'STP', 'SWZ', 'SYC',
        'TCD', 'TGO', 'TZA', 'UGA', 'ZAF', 'ZMB', 'ZWE', 'SDN',
    },
    'North America': {'CAN', 'USA'},
}


def get_region(iso3):
    """Return the REGION_MAP key whose set contains `iso3`, or 'Other' if none does.

    When a code is listed under more than one region, the first region in
    REGION_MAP's insertion order wins.
    """
    matching = (name for name, members in REGION_MAP.items() if iso3 in members)
    return next(matching, 'Other')

# ── Income classification ─────────────────────────────────────────────────
# One number per country: the median of gdp_pc_ppp over that country's rows in
# the year-filtered panel. Computed on the full panel, not on either sample.
gdp_all = panel.groupby('iso3')['gdp_pc_ppp'].median()

def get_income_group(iso3):
    """Return the income-group label for `iso3` from its median gdp_pc_ppp.

    Countries absent from `gdp_all`, or whose median is missing, are labelled
    'Unknown'. Otherwise the label is that of the highest floor the median
    reaches (floors are inclusive); anything under the lowest floor is 'Low'.
    """
    median_gdp = gdp_all.get(iso3, np.nan)
    if pd.isna(median_gdp):
        return 'Unknown'

    # (inclusive floor, label), checked from the highest floor downwards.
    floors = (
        (30000, 'High (>$30K)'),
        (10000, 'Upper-middle ($10-30K)'),
        (3000, 'Lower-middle ($3-10K)'),
    )
    for floor, label in floors:
        if median_gdp >= floor:
            return label
    return 'Low (<$3K)'


# ════════════════════════════════════════════════════════════════════════════
# TABLE 1: Panel composition by region
# ════════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 1: PANEL COMPOSITION BY REGION")
print("=" * 70)

regions = ['East Asia & Pacific', 'Europe & Central Asia', 'Latin America & Caribbean',
           'Middle East & North Africa', 'South Asia', 'Sub-Saharan Africa', 'North America']

# Denominators for the percentage columns: rows with a non-missing ca_gdp.
# Only ca_gdp is checked here; no other regressors are required to be present.
total_ca69 = fp69['ca_gdp'].notna().sum()
total_ca140 = fp140['ca_gdp'].notna().sum()


def _ca_obs(sample, members):
    """Count rows of `sample` whose iso3 is in `members` and whose ca_gdp is not missing."""
    return int((sample['iso3'].isin(members) & sample['ca_gdp'].notna()).sum())


def _ca_share(part, whole):
    """Return `part` as a percentage of `whole`, or 0 when `whole` is not positive."""
    return part / whole * 100 if whole > 0 else 0


def _composition_row(label, members_69, members_140):
    """Build one Table 1 row for the given country sets of the two samples."""
    ca69 = _ca_obs(fp69, members_69)
    ca140 = _ca_obs(fp140, members_140)
    return {
        'Region': label,
        'Countries_69': len(members_69),
        'Countries_140': len(members_140),
        'Countries_added': len(members_140 - members_69),
        'CA_obs_69': ca69,
        'CA_obs_140': ca140,
        'OECD_69': len(members_69 & OECD_MEMBERS),
        'OECD_140': len(members_140 & OECD_MEMBERS),
        'Pct_CA_obs_69': _ca_share(ca69, total_ca69),
        'Pct_CA_obs_140': _ca_share(ca140, total_ca140),
    }


rows = []
assigned_69, assigned_140 = set(), set()
for region in regions:
    c69 = {c for c in countries_69 if get_region(c) == region}
    c140 = {c for c in countries_140 if get_region(c) == region}
    assigned_69 |= c69
    assigned_140 |= c140
    rows.append(_composition_row(region, c69, c140))

# Sample countries that fall in none of the listed regions get their own row,
# with percentages computed like every other row so the column sums to 100.
c69_other = countries_69 - assigned_69
c140_other = countries_140 - assigned_140
if c69_other or c140_other:
    rows.append(_composition_row('Other', c69_other, c140_other))

# Totals
rows.append({
    'Region': 'Total',
    'Countries_69': len(countries_69),
    'Countries_140': len(countries_140),
    'Countries_added': len(countries_added),
    'CA_obs_69': total_ca69,
    'CA_obs_140': total_ca140,
    'OECD_69': len(countries_69 & OECD_MEMBERS),
    'OECD_140': len(countries_140 & OECD_MEMBERS),
    'Pct_CA_obs_69': 100.0,
    'Pct_CA_obs_140': 100.0,
})

comp_df = pd.DataFrame(rows)
comp_df.to_csv(TABLE_DIR / "table1_composition_by_region.csv", index=False)

# OECD share of CA observations
oecd_ca_69 = _ca_obs(fp69, OECD_MEMBERS)
oecd_ca_140 = _ca_obs(fp140, OECD_MEMBERS)
oecd_share_69 = _ca_share(oecd_ca_69, total_ca69)
oecd_share_140 = _ca_share(oecd_ca_140, total_ca140)

print(f"\nOECD CA observation share: {oecd_share_69:.1f}% (69c) → {oecd_share_140:.1f}% (140c)")

# Markdown (the percentage columns are CSV-only)
md = ["# Table 1: Panel Composition by Region\n"]
md.append("69-country sample = EBA(49) + SSA(20); 140-country = 69 + EU expansion(10) + Tier 1 expansion(~61)\n")
md.append("| Region | Countries (69) | Countries (140) | Added | CA obs (69) | CA obs (140) | OECD (69) | OECD (140) |")
md.append("|:---|---:|---:|---:|---:|---:|---:|---:|")
for _, r in comp_df.iterrows():
    md.append(f"| {r['Region']} | {r['Countries_69']} | {r['Countries_140']} | "
              f"{r['Countries_added']} | {r['CA_obs_69']:,} | {r['CA_obs_140']:,} | "
              f"{r['OECD_69']} | {r['OECD_140']} |")

md.append(f"\nOECD share of CA observations: **{oecd_share_69:.1f}%** (69-country) → **{oecd_share_140:.1f}%** (140-country)")

md_text = "\n".join(md)
# The text contains non-ASCII characters ("→"); write UTF-8 explicitly rather
# than relying on the locale's default encoding.
(TABLE_DIR / "table1_composition_by_region.md").write_text(md_text, encoding="utf-8")
print(md_text)


# ════════════════════════════════════════════════════════════════════════════
# TABLE 1B: Panel composition by income group
# ════════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 1B: PANEL COMPOSITION BY INCOME GROUP")
print("=" * 70)

income_groups = ['High (>$30K)', 'Upper-middle ($10-30K)', 'Lower-middle ($3-10K)', 'Low (<$3K)', 'Unknown']

# Classify every sample country once instead of once per income group.
income_of = {c: get_income_group(c) for c in countries_69 | countries_140}

rows_inc = []
for ig in income_groups:
    c69 = {c for c in countries_69 if income_of[c] == ig}
    c140 = {c for c in countries_140 if income_of[c] == ig}

    # All panel rows of this income group in each sample. The means below use
    # every row with a non-missing value (pandas skips NaN), not just CA rows.
    sub69 = fp69[fp69['iso3'].isin(c69)]
    sub140 = fp140[fp140['iso3'].isin(c140)]

    rows_inc.append({
        'Income_group': ig,
        'Countries_69': len(c69),
        'Countries_140': len(c140),
        'Countries_added': len(c140 - c69),
        'CA_obs_69': int(sub69['ca_gdp'].notna().sum()),
        'CA_obs_140': int(sub140['ca_gdp'].notna().sum()),
        'Mean_Z1_69': sub69['Z_1'].mean(),
        'Mean_Z1_140': sub140['Z_1'].mean(),
        'Mean_KAOPEN_69': sub69['kaopen'].mean(),
        'Mean_KAOPEN_140': sub140['kaopen'].mean(),
        'OECD_69': len(c69 & OECD_MEMBERS),
        'OECD_140': len(c140 & OECD_MEMBERS),
    })

inc_df = pd.DataFrame(rows_inc)
inc_df.to_csv(TABLE_DIR / "table1b_composition_by_income.csv", index=False)


def _fmt_mean(value):
    """Format a mean to two decimals, or an em dash when it is missing (empty group)."""
    return f"{value:.2f}" if pd.notna(value) else "—"


md2 = ["# Table 1b: Panel Composition by Income Group\n"]
md2.append("| Income Group | Countries (69) | Countries (140) | Added | CA obs (69) | CA obs (140) | Mean Z₁ (69) | Mean Z₁ (140) | Mean KAOPEN (69) | Mean KAOPEN (140) |")
md2.append("|:---|---:|---:|---:|---:|---:|---:|---:|---:|---:|")
for _, r in inc_df.iterrows():
    z69 = _fmt_mean(r['Mean_Z1_69'])
    z140 = _fmt_mean(r['Mean_Z1_140'])
    ka69 = _fmt_mean(r['Mean_KAOPEN_69'])
    ka140 = _fmt_mean(r['Mean_KAOPEN_140'])
    md2.append(f"| {r['Income_group']} | {r['Countries_69']} | {r['Countries_140']} | "
               f"{r['Countries_added']} | {r['CA_obs_69']:,} | {r['CA_obs_140']:,} | "
               f"{z69} | {z140} | {ka69} | {ka140} |")

md2_text = "\n".join(md2)
# The text contains non-ASCII characters ("Z₁", "—"); write UTF-8 explicitly
# rather than relying on the locale's default encoding.
(TABLE_DIR / "table1b_composition_by_income.md").write_text(md2_text, encoding="utf-8")
print(md2_text)


# ════════════════════════════════════════════════════════════════════════════
# TABLE 1C: KAOPEN distribution comparison
# ════════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 1C: KAOPEN DISTRIBUTION")
print("=" * 70)

# The ceiling is the largest KAOPEN value in the whole year-filtered panel,
# not the maximum within either estimation sample.
kaopen_max = panel['kaopen'].max()
ka69, ka140 = (sample['kaopen'].dropna() for sample in (fp69, fp140))

# Share (in percent) of non-missing observations within 0.01 of that ceiling.
ceiling_cutoff = kaopen_max - 0.01
ceiling_69 = (ka69 >= ceiling_cutoff).mean() * 100
ceiling_140 = (ka140 >= ceiling_cutoff).mean() * 100

def between_within_var(df, var, entity='iso3'):
    """Split the variance of `var` into between-entity and within-entity parts.

    Rows with a missing `entity` or `var` are dropped first. The split is the
    exact sum-of-squares decomposition

        sum (x - grand_mean)^2  =  sum (group_mean - grand_mean)^2   [between]
                                 + sum (x - group_mean)^2            [within]

    with each observation carrying its own group's mean, so entities with more
    observations weigh more in the between part. All three sums are divided by
    N - 1, which makes `total_var` the ordinary sample variance of `var`.
    Consequently between + within == total, and neither part can be negative,
    even when the panel is unbalanced.

    Returns:
        (between_var, within_var, total_var). All three are NaN when fewer
        than two usable observations remain.
    """
    valid = df[[entity, var]].dropna()
    n_obs = len(valid)
    if n_obs < 2:
        # A sample variance needs at least two observations.
        return np.nan, np.nan, np.nan

    values = valid[var]
    grand_mean = values.mean()
    # Each row's own entity mean, aligned to `values` row by row.
    entity_means = valid.groupby(entity)[var].transform('mean')

    denom = n_obs - 1
    total_var = ((values - grand_mean) ** 2).sum() / denom
    between_var = ((entity_means - grand_mean) ** 2).sum() / denom
    within_var = ((values - entity_means) ** 2).sum() / denom
    return between_var, within_var, total_var

# (between_var, within_var, total_var) of KAOPEN for each sample.
bw69 = between_within_var(fp69, 'kaopen')
bw140 = between_within_var(fp140, 'kaopen')


def _variance_shares(decomposition):
    """Return (between %, within %) of total variance; NaN when total is zero or missing."""
    between, within, total = decomposition
    if not total > 0:
        return np.nan, np.nan
    return between / total * 100, within / total * 100


between_share_69, within_share_69 = _variance_shares(bw69)
between_share_140, within_share_140 = _variance_shares(bw140)

kaopen_stats = {
    'Statistic': ['Mean', 'Std Dev', 'Min', 'Max', 'At ceiling (%)',
                  'Between variance share (%)', 'Within variance share (%)'],
    '69-country': [ka69.mean(), ka69.std(), ka69.min(), ka69.max(), ceiling_69,
                   between_share_69, within_share_69],
    '140-country': [ka140.mean(), ka140.std(), ka140.min(), ka140.max(), ceiling_140,
                    between_share_140, within_share_140],
}
kaopen_df = pd.DataFrame(kaopen_stats)
kaopen_df.to_csv(TABLE_DIR / "table1c_kaopen_distribution.csv", index=False)

md3 = ["# Table 1c: KAOPEN Distribution Comparison\n"]
md3.append("| Statistic | 69-country | 140-country |")
md3.append("|:---|---:|---:|")
for _, r in kaopen_df.iterrows():
    v69 = f"{r['69-country']:.2f}"
    v140 = f"{r['140-country']:.2f}"
    md3.append(f"| {r['Statistic']} | {v69} | {v140} |")
md3.append(f"\nKAOPEN ceiling-bunching: **{ceiling_69:.1f}%** → **{ceiling_140:.1f}%** of observations.")
md3.append(f"Between-country variance share: **{between_share_69:.1f}%** → **{between_share_140:.1f}%**.")

md3_text = "\n".join(md3)
# The text contains non-ASCII characters ("→"); write UTF-8 explicitly rather
# than relying on the locale's default encoding.
(TABLE_DIR / "table1c_kaopen_distribution.md").write_text(md3_text, encoding="utf-8")
print(md3_text)


# ════════════════════════════════════════════════════════════════════════════
# TABLE 1D: Variable coverage comparison
# ════════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 1D: VARIABLE COVERAGE")
print("=" * 70)

key_vars = ['ca_gdp', 'Z_1', 'fiscal_bal_gdp', 'kaopen', 'nfa_gdp',
            'gdp_pc_ppp', 'govt_bond_10y', 'short_rate_3m', 'lending_rate',
            'life_expectancy', 'old_dep', 'youth_dep', 'gross_investment_gdp',
            'trade_openness', 'pension_spending_gdp']

# Fixed column order, so the CSV keeps its header even if no variable is found.
cov_columns = ['Variable', 'Obs_69', 'Countries_69', 'Obs_140', 'Countries_140', 'Obs_ratio']

cov_rows = []
skipped_vars = []
for v in key_vars:
    if v not in fp69.columns or v not in fp140.columns:
        # Variables absent from the panel are left out of the table, but reported below.
        skipped_vars.append(v)
        continue
    has69 = fp69[v].notna()
    has140 = fp140[v].notna()
    n69 = int(has69.sum())
    n140 = int(has140.sum())
    cov_rows.append({
        'Variable': v,
        'Obs_69': n69, 'Countries_69': fp69.loc[has69, 'iso3'].nunique(),
        'Obs_140': n140, 'Countries_140': fp140.loc[has140, 'iso3'].nunique(),
        # Undefined when the 69-country sample has no observations at all.
        'Obs_ratio': n140 / n69 if n69 > 0 else np.nan,
    })

if skipped_vars:
    print(f"Skipped (column not in panel): {', '.join(skipped_vars)}")

cov_df = pd.DataFrame(cov_rows, columns=cov_columns)
cov_df.to_csv(TABLE_DIR / "table1d_variable_coverage.csv", index=False)

md4 = ["# Table 1d: Variable Coverage Comparison\n"]
md4.append("| Variable | Obs (69c) | Countries (69c) | Obs (140c) | Countries (140c) | Expansion ratio |")
md4.append("|:---|---:|---:|---:|---:|---:|")
for _, r in cov_df.iterrows():
    # Show an em dash instead of "nan×" when the ratio is undefined.
    ratio = f"{r['Obs_ratio']:.1f}×" if pd.notna(r['Obs_ratio']) else "—"
    md4.append(f"| {r['Variable']} | {r['Obs_69']:,} | {r['Countries_69']} | "
               f"{r['Obs_140']:,} | {r['Countries_140']} | {ratio} |")

md4_text = "\n".join(md4)
# The text contains non-ASCII characters ("×", "—"); write UTF-8 explicitly
# rather than relying on the locale's default encoding.
(TABLE_DIR / "table1d_variable_coverage.md").write_text(md4_text, encoding="utf-8")
print(md4_text)


# ════════════════════════════════════════════════════════════════════════════
# Summary
# ════════════════════════════════════════════════════════════════════════════
# Console-only recap of the headline numbers computed above; writes no files.
print("\n" + "=" * 70)
print("SUMMARY")
print("=" * 70)

n_small, n_large, n_new = len(countries_69), len(countries_140), len(countries_added)
print(f"Estimation sample: {n_small} → {n_large} (+{n_new})")
print(f"CA observations: {total_ca69:,} → {total_ca140:,}")
print(f"OECD share of CA obs: {oecd_share_69:.1f}% → {oecd_share_140:.1f}%")
print(f"KAOPEN ceiling-bunching: {ceiling_69:.1f}% → {ceiling_140:.1f}%")

# bw* tuples are (between_var, within_var, total_var).
btw_pct_69 = bw69[0] / bw69[2] * 100
btw_pct_140 = bw140[0] / bw140[2] * 100
print(f"KAOPEN between-variance share: {btw_pct_69:.1f}% → {btw_pct_140:.1f}%")

print(f"\nSaved 4 tables to {TABLE_DIR}")
print("Phase 1 complete.")
